from StillGAN.models import create_model
from StillGAN.data import create_dataset
from StillGAN.options.test_options import TestOptions
from StillGAN.util.visualizer import save_images
from StillGAN.util import html
from StillGAN.util.util import *
from StillGAN.models.networks import ResUNet
from torch import nn
import functools
from torch.autograd import Variable
from torchvision import transforms
import cv2
# Utilities: pass-through module and normalization-layer factory
class Identity(nn.Module):
    """A parameter-free module whose forward pass returns its input unchanged.

    Serves as the "no normalization" layer produced by ``get_norm_layer('none')``.
    """

    def forward(self, x):
        # Nothing to compute: hand the input straight back.
        output = x
        return output
def get_norm_layer(norm_type='instance'):
    """Build a factory that creates 2-D normalization layers.

    Parameters:
        norm_type (str) -- which normalization to use: batch | instance | none

    Returns:
        A callable taking the channel count and returning an ``nn.Module``.
        'batch'    -> BatchNorm2d with learnable affine parameters and running statistics.
        'instance' -> InstanceNorm2d without affine parameters or running statistics.
        'none'     -> a callable that ignores its argument and returns an ``Identity``.

    Raises:
        NotImplementedError -- if ``norm_type`` is not one of the names above.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    if norm_type == 'none':
        # The channel count is accepted for interface compatibility but unused.
        return lambda x: Identity()
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
# Preprocessing for enhancement
def pre_still(img, size=512):
    """Turn a PIL image into a normalized single-image batch for the StillGAN generator.

    Parameters:
        img (PIL.Image.Image) -- input image; converted to RGB if it is not already
        size (int) -- side length of the square the image is resized to
                      (default 512, the resolution this pipeline has always used)

    Returns:
        float tensor of shape (1, 3, size, size) with values scaled to [-1, 1]
    """
    # ToTensor keeps the source channel count, but Normalize below expects exactly
    # three channels, so RGBA / greyscale / palette images would otherwise raise.
    if isinstance(img, Image.Image) and img.mode != 'RGB':
        img = img.convert('RGB')
    # The former RandomCrop(512) after a 512x512 resize was a no-op (crop size ==
    # image size), so it is dropped; with a configurable size it could also fail.
    trans = transforms.Compose([
        transforms.Resize([size, size], Image.BICUBIC),
        transforms.ToTensor(),  # HWC [0, 255] -> CHW float in [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # -> [-1, 1]
    ])
    # torch.autograd.Variable has been a no-op wrapper since PyTorch 0.4; a plain
    # tensor (requires_grad=False by default) is equivalent.
    return trans(img).unsqueeze(0).float()
# StillGAN image enhancement (originally labelled "马博增强", i.e. Ma Bo's enhancement)
def stillgan(img_path, checkpoint_path='StillGAN/checkpoints/isee_csigan/120_net_G_A.pth'):
    """Enhance one image with the pretrained StillGAN generator (G_A).

    The image is preprocessed by ``pre_still`` (resize + normalize), passed through a
    ResUNet generator on the CPU, and resized back to its original dimensions.

    Parameters:
        img_path (str) -- path of the image to enhance
        checkpoint_path (str) -- path of the generator state dict to load

    Returns:
        numpy image in BGR channel order (ready for ``cv2.imwrite``) with the same
        width and height as the input. Presumably uint8, as produced by
        ``tensor2im`` -- confirm against StillGAN.util.util.
    """
    # Use a context manager so the file handle is released. convert('RGB') loads the
    # pixel data (so it survives the close) and guarantees the 3 channels that the
    # normalization and the 3-input-channel generator expect.
    with Image.open(img_path) as source:
        rgb = source.convert('RGB')
    width, height = rgb.size  # PIL reports (width, height)
    pic = pre_still(rgb)

    model = ResUNet(3, 3, 64, norm_layer=get_norm_layer())
    state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    # Inference only: eval() switches off training-only behaviour (e.g. dropout, if
    # ResUNet has any), and no_grad() avoids building an autograd graph.
    model.eval()
    with torch.no_grad():
        output = model(pic)

    # tensor2im is assumed to return an HxWx3 RGB array; OpenCV works in BGR.
    image = cv2.cvtColor(tensor2im(output), cv2.COLOR_RGB2BGR)
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(image, (width, height), interpolation=cv2.INTER_CUBIC)
if __name__ == '__main__':
    # Command-line entry point: enhance one image and write the result to disk.
    # With no arguments the original hard-coded input/output paths are used.
    import argparse

    parser = argparse.ArgumentParser(
        description='Enhance an image with the pretrained StillGAN generator.')
    parser.add_argument('input', nargs='?',
                        default='/Users/xfdw/Desktop/project/back/tmp/model1/raw/1.png',
                        help='path of the image to enhance')
    parser.add_argument('output', nargs='?',
                        default='/Users/xfdw/Desktop/re/a.png',
                        help='path where the enhanced image is written')
    args = parser.parse_args()

    # img: the enhanced image (BGR, same size as the input)
    img = stillgan(args.input)
    # cv2.imwrite can signal failure (e.g. the target directory does not exist) by
    # returning False instead of raising, so check it explicitly.
    if not cv2.imwrite(args.output, img):
        raise SystemExit('failed to write enhanced image to %s' % args.output)
    print('enhanced image written to %s' % args.output)